import io
import math
import os
from flask import Flask, request, send_file, jsonify
from flask_cors import CORS
from PIL import Image, ImageFile
import numpy as np
import time
import logging
import u2net
import uuid
import base64
import re

# Configure the root logger so INFO-level (and above) records are emitted.
logging.basicConfig(level=logging.INFO)

# Initialize the Flask application; every route below registers on this object.
app = Flask(__name__)
# Enable Cross-Origin Resource Sharing with flask-cors defaults so browser
# front-ends served from another origin can call this API.
CORS(app)


@app.route('/api/hello', methods=['GET'])
def hello():
    """Liveness endpoint: return a small HTML banner naming the service."""
    banner = '<h1>Background Removal API</h1>'
    return banner


# Route HTTP POSTs to this method
@app.route('/api/removebg', methods=['POST'])
def removebg():
    """Remove the background of a base64-encoded image and store the result.

    Form fields:
        base64Data: the image as base64 text, optionally prefixed with a
            ``data:image/...;base64,`` header (required).
        backgroundColor: '蓝色' (blue) or '红色' (red) selects a solid fill;
            any other value, or none, yields white.

    Returns:
        JSON ``{'code': 200, 'imgPath': <download URL>, 'imgName': <file name>}``
        on success, or ``{'code': 400, 'message': ...}`` with HTTP status 400
        when ``base64Data`` is missing or cannot be decoded as an image.
    """
    start = time.time()

    base64_data = request.form.get("base64Data")
    if not base64_data:
        return jsonify({'code': 400, 'message': 'base64Data is required'}), 400
    try:
        # Normalise to RGB. The final cutout is RGB anyway, and convert()
        # forces a full decode, so corrupt payloads fail here with a 400
        # instead of later inside the model. It also gives np.array() a
        # consistent (H, W, 3) layout whatever the source mode (P, L, RGBA, ...).
        # NOTE(review): presumably u2net.run expects an RGB array; confirm.
        img = base64_to_image(base64_data).convert("RGB")
    except (ValueError, OSError) as err:
        # binascii.Error (bad base64) subclasses ValueError;
        # PIL.UnidentifiedImageError and truncated-image errors subclass OSError.
        logging.warning('Rejected undecodable image: %s', err)
        return jsonify({'code': 400, 'message': 'base64Data is not a valid image'}), 400

    # Ensure the image size is within 1024x1024 (thumbnail keeps aspect ratio).
    if img.size[0] > 1024 or img.size[1] > 1024:
        img.thumbnail((1024, 1024))

    # Process image: the model output is converted to a greyscale mask.
    mask = u2net.run(np.array(img)).convert("L")

    # '蓝色' = blue, '红色' = red; anything else falls back to opaque white.
    bg_color = {
        '蓝色': (71, 151, 221, 255),
        '红色': (200, 44, 59, 255),
    }.get(request.form.get("backgroundColor"), (255, 255, 255, 255))

    img_bg_removed = naive_cutout(img, mask, bg_color)
    img_name = save_img(img_bg_removed)
    logging.info(f'Completed in {time.time() - start:.2f}s')
    # Build the link from the incoming request rather than hard-coding
    # http://localhost:8081, which breaks when PORT is overridden or the
    # service is reached via another host name. url_root ends with '/'.
    full_path = request.url_root + "api/downloadImg?imgName=" + img_name
    return jsonify({'code': 200, 'imgPath': full_path, 'imgName': img_name})


@app.route('/api/downloadImg', methods=['GET'])
def download_img():
    """Send a previously generated image as an attachment.

    Query parameters:
        imgName: file name returned by ``/api/removebg`` (required). Only the
            final path component is used, so lookups cannot escape the
            ``output_img`` directory.
        maxBytes: optional size limit in KiB. Despite its name, it is compared
            against the file size divided by 1024. When the stored file is
            larger, a re-encoded copy made by ``compress_image`` is sent instead.

    Returns:
        The file download; HTTP 400 JSON for a missing or invalid parameter;
        HTTP 404 JSON when the image does not exist.
    """
    output_dir = os.path.join(os.getcwd(), 'output_img')

    # Untrusted input: keep only the base name ("../../etc/passwd" -> "passwd")
    # so a request cannot read files outside output_dir.
    img_name = os.path.basename(request.args.get('imgName') or '')
    if not img_name:
        return jsonify({'code': 400, 'message': 'imgName is required'}), 400
    full_path = os.path.join(output_dir, img_name)
    if not os.path.isfile(full_path):
        return jsonify({'code': 404, 'message': 'image not found'}), 404

    raw_limit = request.args.get("maxBytes")
    if raw_limit is None:
        # No size limit requested: serve the stored file unchanged.
        return send_file(full_path, as_attachment=True)
    try:
        max_kib = float(raw_limit)
    except ValueError:
        return jsonify({'code': 400, 'message': 'maxBytes must be a number'}), 400

    # Round up exactly as compress_image does, so both agree on whether the
    # file is over the limit. Floor division let e.g. 100.5 KiB pass a
    # 100 KiB limit.
    file_size_kib = math.ceil(os.path.getsize(full_path) / 1024)
    if file_size_kib > max_kib:
        return _send_compressed(full_path, output_dir, max_kib)
    return send_file(full_path, as_attachment=True)


def _send_compressed(source_path, output_dir, max_kib):
    """Send a re-encoded copy of *source_path* aiming to fit within *max_kib* KiB."""
    target_file = os.path.join(output_dir, str(uuid.uuid4()) + ".jpg")
    try:
        compress_image(source_path, target_file, max_kib)
        response = send_file(target_file, as_attachment=True)
    except Exception:
        # Do not leave a temporary copy behind when compression/sending fails.
        _remove_quietly(target_file)
        raise
    # Defer deletion until the response is closed: send_file keeps the file
    # open while streaming, and on Windows an open file cannot be removed.
    response.call_on_close(lambda: _remove_quietly(target_file))
    return response


def _remove_quietly(path):
    """Best-effort delete of a temporary file; failures are logged, not raised."""
    try:
        os.remove(path)
    except FileNotFoundError:
        pass
    except OSError as err:
        logging.warning('Could not remove temporary file %s: %s', path, err)


# Compress an image file by re-encoding it at decreasing quality.
def compress_image(imagefile, targetfile, targetsize):
    """Re-encode *imagefile* into *targetfile* until it fits in *targetsize* KiB.

    Quality starts at 99 and drops by 5 per attempt, down to 14. It stops as
    soon as the written file is no larger than *targetsize*. Sizes are
    rounded up to whole KiB. Pillow picks the output format from
    *targetfile*'s extension.

    If *imagefile* is already within the limit, nothing is written and
    *targetfile* will not exist. If even the lowest quality is too large,
    the last attempt (quality 14) is left in *targetfile*.

    Args:
        imagefile: path of the source image.
        targetfile: path the re-encoded image is written to.
        targetsize: size limit in KiB.
    """
    currentsize = math.ceil(os.path.getsize(imagefile) / 1024)
    # Written with ">" so a NaN limit never triggers any work, as before.
    if currentsize > targetsize:
        # Open the source once (and close it) instead of leaking a new handle
        # on every attempt.
        with Image.open(imagefile) as image:
            for quality in range(99, 10, -5):  # decreasing quality
                image.save(targetfile, optimize=True, quality=quality)
                if math.ceil(os.path.getsize(targetfile) / 1024) <= targetsize:
                    break

def base64_to_image(base64_str):
    """Decode base64 text (optionally a ``data:image/...;base64,`` URL) into a PIL image.

    Args:
        base64_str: base64 text, with or without the data-URL header.

    Returns:
        The decoded image, already loaded so decoding problems surface here.

    Raises:
        ValueError: if *base64_str* is None or empty. ``binascii.Error``, a
            ValueError subclass, may also be raised for malformed base64
            such as bad padding.
        OSError: if the bytes are not a readable image
            (``PIL.UnidentifiedImageError`` or a truncated file).
    """
    if not base64_str:
        raise ValueError("base64_str must be a non-empty string")
    # Strip surrounding whitespace first so a leading blank does not stop the
    # data-URL header from being removed.
    base64_data = re.sub(r'^data:image/.+;base64,', '', base64_str.strip())
    byte_data = base64.b64decode(base64_data)
    img = Image.open(io.BytesIO(byte_data))
    # Image.open is lazy; decode now so corrupt data fails in this helper
    # rather than at the first pixel access somewhere else.
    img.load()
    return img


def save_img(img):
    """Save *img* as ``<cwd>/output_img/<uuid4>.jpg`` and return the file name.

    The output directory is created on first use.

    Args:
        img: image object with a ``save(path)`` method (a PIL image here).

    Returns:
        The generated file name only (not the full path), e.g. ``'<uuid4>.jpg'``.
    """
    output_dir = os.path.join(os.getcwd(), 'output_img')
    # exist_ok=True already tolerates an existing directory, so a separate
    # os.path.exists() check would be redundant and racy.
    os.makedirs(output_dir, exist_ok=True)
    img_name = str(uuid.uuid4()) + ".jpg"
    img.save(os.path.join(output_dir, img_name))
    return img_name


def naive_cutout(img, mask, backgroundColor=(255, 255, 255, 255)):
    """Place the foreground of *img* on a solid background selected by *mask*.

    Args:
        img: source PIL image.
        mask: PIL mask image. It is resized to ``img.size`` with Lanczos
            resampling before compositing. White keeps *img*, black shows the
            background, and grey values blend the two.
        backgroundColor: RGBA fill colour of the new background.

    Returns:
        The composited image converted to RGB.
    """
    scaled_mask = mask.resize(img.size, Image.LANCZOS)
    canvas = Image.new("RGBA", img.size, backgroundColor)
    return Image.composite(img, canvas, scaled_mask).convert('RGB')


if __name__ == '__main__':
    # Development entry point: Flask's built-in server bound to localhost.
    # The port is read from the PORT environment variable (default 8081).
    # debug=True enables the interactive debugger, so never bind this
    # beyond localhost.
    # NOTE(review): FLASK_ENV is deprecated in newer Flask releases; confirm
    # whether setting it still matters for the Flask version in use.
    os.environ['FLASK_ENV'] = 'development'
    listen_port = int(os.environ.get('PORT', 8081))
    app.run(host='localhost', port=listen_port, debug=True)
